// Copyright (c) 2020-present, INSPUR Co, Ltd. All rights reserved.
// This source code is licensed under Apache 2.0 License.
//
// Created by florian on 05.08.15.

#include <assert.h>
#include <algorithm>
#include <functional>
#include "art_n.h"
#include "pure_mem/epoche.h"
#include "art_tree.h"

namespace art_rowex {

// Builds a tree whose root is a newly allocated N256 constructed with the
// arguments (0, {}) and stores `loadKey` in loadKey_.
// NOTE(review): (0, {}) is presumably "level 0, empty prefix" -- confirm
// against N256's constructor.
Tree::Tree(LoadKeyFunction loadKey)
    : root(new N256(0, {})), loadKey_(loadKey) {}

// Tears the tree down by calling N::deleteChildren on the root and then
// N::deleteNode on the root itself.
// NOTE(review): no locking is taken here, so this presumably requires that no
// other thread is still using the tree -- confirm with callers.
Tree::~Tree() {
  N::deleteChildren(root);
  N::deleteNode(root);
}

// Removes the entry for key `k`, but only when the leaf found for `k` holds
// exactly `tid`; in every other case the tree is left unchanged.
//
// The descent starts at `root` and follows one key byte per inner node. Each
// node's version is read before it is inspected; whenever a later validation
// or lock attempt fails, the whole operation starts over from `restart`.
//
// Outcomes once the child slot for the current key byte is known:
//   * prefix mismatch or empty slot -> validate the version and return.
//   * leaf with a different TID     -> unlock and return.
//   * node has exactly two entries and is not the root -> the node is taken
//     out of the tree: the parent's slot is redirected to the remaining
//     entry and the node is marked obsolete and handed to DeleteWhileNoRefs.
//   * otherwise -> N::removeAndUnlock handles the removal.
//
// Although the method is declared const, it modifies nodes reachable through
// the `root` pointer.
//
// NOTE(review): std::tie needs <tuple>, which is not among this file's direct
// includes -- presumably it arrives transitively; confirm.
void Tree::remove(const Key &k, TID tid) const {
restart:
  bool needRestart = false;

  N *node = nullptr;
  N *nextNode = root;
  N *parentNode = nullptr;
  // parentKey is assigned at the top of every loop iteration before it is
  // read, so leaving it uninitialised here is harmless.
  uint8_t parentKey, nodeKey = 0;
  uint32_t level = 0;
  // bool optimisticPrefixMatch = false;

  while (true) {
    // Step one level down: the previous node becomes the parent.
    parentNode = node;
    parentKey = nodeKey;
    node = nextNode;
    // Snapshot of the node's version, used below for validation / locking.
    auto v = node->getVersion();

    switch (checkPrefix(node, k, level)) {  // increases level
      case CheckPrefixResult::NoMatch:
        // The key is not below this node. Only trust that conclusion if the
        // node was neither obsolete nor changed since `v` was read.
        if (N::isObsolete(v) || !node->readUnlockOrRestart(v)) {
          goto restart;
        }
        return;
      case CheckPrefixResult::OptimisticMatch:
        // fallthrough
      case CheckPrefixResult::Match: {
        nodeKey = k[level];
        nextNode = N::getChild(nodeKey, node);

        if (nextNode == nullptr) {
          // No child for this key byte: nothing to remove, provided the read
          // above is still valid.
          if (N::isObsolete(v) ||
              !node->readUnlockOrRestart(v)) {  // TODO needed??
            goto restart;
          }
          return;
        }
        if (N::isLeaf(nextNode)) {
          // Try to lock `node` based on the version read earlier; restart if
          // that fails.
          node->lockVersionOrRestart(v, needRestart);
          if (needRestart) goto restart;

          // The slot holds a different TID than the one requested: leave it.
          if (N::getLeaf(nextNode) != tid) {
            node->writeUnlock();
            return;
          }
          // Any node that has a parent is expected to hold more than one
          // entry at this point.
          assert(parentNode == nullptr || node->getCount() != 1);
          if (node->getCount() == 2 && node != root) {
            // Removing one of two entries would leave a single-entry inner
            // node, so the node itself is removed and the parent is pointed
            // at the remaining entry instead.
            // 1. check remaining entries
            // NOTE(review): getSecondChild presumably returns the entry of
            // `node` other than `nodeKey`, with its key byte -- confirm.
            N *secondNodeN;
            uint8_t secondNodeK;
            std::tie(secondNodeN, secondNodeK) =
                N::getSecondChild(node, nodeKey);
            if (N::isLeaf(secondNodeN)) {
              // Remaining entry is a leaf: only the parent must be locked in
              // addition to `node`.
              parentNode->writeLockOrRestart(needRestart);
              if (needRestart) {
                node->writeUnlock();
                goto restart;
              }

              // N::remove(node, k[level]); not necessary
              N::change(parentNode, parentKey, secondNodeN);

              parentNode->writeUnlock();
              // Release `node` flagged as obsolete, then pass it to
              // DeleteWhileNoRefs (presumably deferred reclamation -- confirm).
              node->writeUnlockObsolete();
              rocksdb::DeleteWhileNoRefs::getInstance()->markNodeForDeletion(
                  node);
            } else {
              // Remaining entry is an inner node whose prefix is modified
              // below, so it is locked too: first the sibling, then the
              // parent.
              uint64_t vChild = secondNodeN->getVersion();
              secondNodeN->lockVersionOrRestart(vChild, needRestart);
              if (needRestart) {
                node->writeUnlock();
                goto restart;
              }
              parentNode->writeLockOrRestart(needRestart);
              if (needRestart) {
                // Release both locks already held before retrying.
                node->writeUnlock();
                secondNodeN->writeUnlock();
                goto restart;
              }

              // N::remove(node, k[level]); not necessary
              N::change(parentNode, parentKey, secondNodeN);
              // NOTE(review): presumably prepends node's prefix and the key
              // byte secondNodeK to the sibling's prefix -- confirm in N.
              secondNodeN->addPrefixBefore(node, secondNodeK);

              parentNode->writeUnlock();
              node->writeUnlockObsolete();
              rocksdb::DeleteWhileNoRefs::getInstance()->markNodeForDeletion(
                  node);
              secondNodeN->writeUnlock();
            }
          } else {
            // General case, delegated to N::removeAndUnlock; it reports via
            // needRestart when the operation has to be retried.
            N::removeAndUnlock(node, k[level], parentNode, parentKey,
                               needRestart);
            if (needRestart) goto restart;
          }
          return;
        }
        // Child is an inner node: consume this key byte and keep descending.
        level++;
      }
    }
  }
}

// Compares the prefix stored in node `n` with key `k`, beginning at key
// offset `level`, and moves `level` forward past everything it consumed.
//
// Result:
//   NoMatch         - the key length does not exceed the node's level, or a
//                     stored prefix byte differs from the key byte.
//   OptimisticMatch - some bytes were skipped without being compared: either
//                     `level` was jumped straight to the node's level, or the
//                     prefix is longer than maxStoredPrefixLength.
//   Match           - every remaining prefix byte was compared and agreed
//                     (including the case of an empty prefix).
typename Tree::CheckPrefixResult Tree::checkPrefix(N *n, const Key &k,
                                                   uint32_t &level) {
  if (k.getKeyLen() <= n->getLevel()) {
    return CheckPrefixResult::NoMatch;
  }

  Prefix stored = n->getPrefi();

  // `level` is behind the start of this node's prefix: skip ahead unchecked.
  if (stored.prefixCount + level < n->getLevel()) {
    level = n->getLevel();
    return CheckPrefixResult::OptimisticMatch;
  }

  // Nothing to compare for a node without a prefix.
  if (!(stored.prefixCount > 0)) {
    return CheckPrefixResult::Match;
  }

  // Compare only the bytes that are physically kept in the node, starting at
  // the prefix position that corresponds to the current `level`.
  uint32_t idx = (level + stored.prefixCount) - n->getLevel();
  const auto comparable = std::min(stored.prefixCount, maxStoredPrefixLength);
  while (idx < comparable) {
    if (stored.prefix[idx] != k[level]) {
      return CheckPrefixResult::NoMatch;
    }
    ++level;
    ++idx;
  }

  // The remainder of an over-long prefix is not stored; step over it.
  if (stored.prefixCount > maxStoredPrefixLength) {
    level += stored.prefixCount - maxStoredPrefixLength;
    return CheckPrefixResult::OptimisticMatch;
  }
  return CheckPrefixResult::Match;
}

}  // namespace art_rowex
